
import torch
import torch.nn.functional as F



def pod_spatial_loss(f_new, f_old, normalize=True):
    """Return an MSE distillation loss between spatially pooled feature maps.

    Both tensors are averaged over dim 3 and, separately, over dim 2. The
    mean-squared errors between the new and old pooled maps are summed.
    With dims read as (N, C, H, W), this compares per-row (height) and
    per-column (width) channel profiles.

    Args:
        f_new: 4-D feature tensor. Dim 1 is treated as the channel axis;
            dims 2 and 3 are the spatial axes that get pooled.
        f_old: 4-D tensor with exactly the same shape as ``f_new``. The name
            presumably refers to features from an earlier/frozen model
            (verify against caller).
            NOTE(review): it is not detached here. If the old model must
            not receive gradients, the caller must pass detached or
            ``no_grad`` features.
        normalize: If True, L2-normalize both tensors along dim 1 (the
            channel axis) before pooling.

    Returns:
        Scalar tensor:
        ``mse(f_new.mean(3), f_old.mean(3)) + mse(f_new.mean(2), f_old.mean(2))``.

    Raises:
        ValueError: If either input is not 4-D, or if their shapes differ.
            ``F.mse_loss`` would otherwise broadcast mismatched shapes
            (emitting only a warning) and return a meaningless loss.
    """
    if f_new.dim() != 4 or f_old.dim() != 4:
        raise ValueError(
            f"pod_spatial_loss expects 4-D tensors, got {f_new.dim()}-D "
            f"and {f_old.dim()}-D"
        )
    if f_new.shape != f_old.shape:
        raise ValueError(
            f"pod_spatial_loss shape mismatch: {tuple(f_new.shape)} vs "
            f"{tuple(f_old.shape)}"
        )

    if normalize:
        # Unit-length along the channel axis at every spatial location, so
        # the loss compares feature directions rather than magnitudes.
        f_new = F.normalize(f_new, p=2, dim=1)
        f_old = F.normalize(f_old, p=2, dim=1)

    # Pool over dim 3 -> one profile per index of dim 2 (rows).
    loss_h = F.mse_loss(f_new.mean(dim=3), f_old.mean(dim=3))
    # Pool over dim 2 -> one profile per index of dim 3 (columns).
    loss_w = F.mse_loss(f_new.mean(dim=2), f_old.mean(dim=2))
    return loss_h + loss_w
